{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
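    "Code for the **\"Flash/No Flash\"** figure: a network that takes the flash photo as input is fitted to the no-flash photo of the same scene."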
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "# To pin the notebook to a single GPU, uncomment the next line (it must be set before any CUDA call).\n",
    "#os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n",
    "\n",
    "import numpy as np\n",
    "from models import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "from utils.denoising_utils import *\n",
    "from utils.sr_utils import load_LR_HR_imgs_sr\n",
    "\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "# Run on the GPU when one is available; otherwise fall back to the (much slower) CPU\n",
    "# instead of failing at the first `.type(dtype)` call.\n",
    "dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
    "\n",
    "imsize = -1  # image size for load_LR_HR_imgs_sr; -1 presumably means 'do not resize' -- TODO confirm\n",
    "PLOT = True  # show intermediate reconstructions during optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both photos are loaded once each, with identical loader arguments (including\n",
    "# enforse_div32='CROP'), because the network output is later compared pixel-wise\n",
    "# with the no-flash image.\n",
    "flash_path = 'data/flash_no_flash/cave01_00_flash.jpg'\n",
    "noflash_path = 'data/flash_no_flash/cave01_01_noflash.jpg'\n",
    "\n",
    "img_flash = load_LR_HR_imgs_sr(flash_path, imsize, 1, enforse_div32='CROP')['HR_pil']\n",
    "img_flash_np = pil_to_np(img_flash)\n",
    "\n",
    "img_noflash = load_LR_HR_imgs_sr(noflash_path, imsize, 1, enforse_div32='CROP')['HR_pil']\n",
    "img_noflash_np = pil_to_np(img_noflash)\n",
    "\n",
    "assert img_flash_np.shape == img_noflash_np.shape, 'flash and no-flash images must have the same shape'\n",
    "\n",
    "g = plot_image_grid([img_flash_np, img_noflash_np], 3, 12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pad = 'reflection'\n",
    "OPT_OVER = 'net'  # what get_params hands to the optimizer; 'net' presumably selects the network weights only\n",
    "\n",
    "num_iter = 601\n",
    "LR = 0.1\n",
    "OPTIMIZER = 'adam'\n",
    "reg_noise_std = 0.0  # > 0 adds fresh Gaussian noise to the input every iteration; disabled here\n",
    "show_every = 50  # plot a preview every `show_every` iterations\n",
    "figsize = 6\n",
    "\n",
    "# The flash image itself is the network input, so the input depth is its channel\n",
    "# count (dim 1 of the NCHW tensor fed to the conv net) rather than a hard-coded 3.\n",
    "net_input = np_to_torch(img_flash_np).type(dtype)\n",
    "input_depth = net_input.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# skip() configuration: one list entry per scale (five scales), 128 channels on the\n",
    "# down and up paths, 4-channel skip connections, nearest upsampling for the first two\n",
    "# entries and bilinear for the remaining three.\n",
    "net = skip(\n",
    "    input_depth, 3,\n",
    "    num_channels_down=[128] * 5,\n",
    "    num_channels_up=[128] * 5,\n",
    "    num_channels_skip=[4] * 5,\n",
    "    upsample_mode=['nearest'] * 2 + ['bilinear'] * 3,\n",
    "    need_sigmoid=True,\n",
    "    need_bias=True,\n",
    "    pad=pad,\n",
    ").type(dtype)\n",
    "\n",
    "mse = torch.nn.MSELoss().type(dtype)\n",
    "\n",
    "# img_noflash_var is the fitting target; img_flash_var holds the same values as net_input.\n",
    "img_flash_var = np_to_torch(img_flash_np).type(dtype)\n",
    "img_noflash_var = np_to_torch(img_noflash_np).type(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Pristine copy of the flash input, plus a same-shaped scratch buffer that is\n",
    "# refilled in place by normal_() when input perturbation is enabled.\n",
    "net_input_saved = net_input.detach().clone()\n",
    "noise = net_input.detach().clone()\n",
    "\n",
    "i = 0  # iteration counter shared with closure() via `global`\n",
    "\n",
    "def closure():\n",
    "    \"\"\"Run one forward/backward pass of the network and return the MSE loss.\n",
    "\n",
    "    Passed to optimize() below as its step callback. Gradients are produced with\n",
    "    backward(); presumably optimize() zeroes them between steps -- TODO confirm.\n",
    "    \"\"\"\n",
    "    global i, net_input\n",
    "\n",
    "    if reg_noise_std > 0:\n",
    "        # Perturb the saved clean input with fresh Gaussian noise.\n",
    "        net_input = net_input_saved + noise.normal_() * reg_noise_std\n",
    "\n",
    "    out = net(net_input)\n",
    "    total_loss = mse(out, img_noflash_var)\n",
    "    total_loss.backward()\n",
    "\n",
    "    loss_value = total_loss.item()\n",
    "    print('Iteration %05d    Loss %f' % (i, loss_value), '\\r', end='')\n",
    "\n",
    "    if PLOT and i % show_every == 0:\n",
    "        preview = np.clip(torch_to_np(out), 0, 1)\n",
    "        plot_image_grid([preview], factor=figsize, nrow=1)\n",
    "\n",
    "    i += 1\n",
    "    return total_loss\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes the optimization gets stuck at a reddish image; if that happens, rerun the notebook from the top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the clean flash input: when reg_noise_std > 0, `net_input` holds the last\n",
    "# noise-perturbed input from the optimization loop. no_grad() avoids building an\n",
    "# autograd graph for this inference-only pass.\n",
    "with torch.no_grad():\n",
    "    out_np = torch_to_np(net(net_input_saved))\n",
    "q = plot_image_grid([np.clip(out_np, 0, 1), img_noflash_np], factor=13);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
